import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

# PRNG key built from the fixed seed 0, so anything derived from it is
# reproducible from run to run.
key = random.PRNGKey(0)

# Derivative of tanh with respect to its first (and only) positional argument.
grad_tanh = grad(jnp.tanh, argnums=0)

# Evaluate the derivative at x = 2.0 and report it on stdout.
print(
    'grad_tanh(2.0):',
    grad_tanh(2.0),
    sep=' ',
)
